What is k-means clustering?

  • One of the most popular unsupervised learning techniques (unsupervised means there is no “right answer”: the data hasn’t been labelled)

  • Partitions the dataset into distinct, non-overlapping clusters, so that each data point belongs to exactly one cluster.

In k-means clustering, each cluster is represented by its center (i.e., its centroid), which is the mean of the points assigned to that cluster.
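A quick way to see the "centroid = mean of assigned points" idea is to recompute the centroids by hand and compare them with what kmeans() returns. This is only a sketch: the choice of k = 3 and the seed value are arbitrary assumptions made for illustration.

```r
# Built-in USArrests data, standardised so all variables are on the same scale
set.seed(123)                                  # arbitrary seed
x  <- scale(USArrests)
km <- kmeans(x, centers = 3, nstart = 10)      # k = 3 chosen only for illustration

# Recompute each centroid as the column means of the points in that cluster
manual <- t(sapply(1:3, function(j) colMeans(x[km$cluster == j, , drop = FALSE])))

# Compare with the centroids stored in the fitted object (dimnames ignored)
all.equal(unname(manual), unname(km$centers))
```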

(Note: I read that k-means clustering can also be helpful for supervised learning; the example used the clusters in a random forest.)

  • k represents the number of groups pre-specified by the analyst

k-means Clustering Steps

  1. Specify the number of clusters (k) to be created
  2. Randomly select positions for the initial cluster centroids
  3. Assign each observation to its closest centroid
  4. For each of the k clusters, update the centroid by calculating the new mean of all the data points in the cluster.
  5. Iteratively minimize the total within-cluster sum of squares (WSS). That is, repeat steps 3 and 4 until the cluster assignments stop changing or the maximum number of iterations is reached.
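The steps above can be sketched in a few lines of base R. simple_kmeans() is a hypothetical helper written for these notes, not a function from any package, and it is not how stats::kmeans() works internally (that uses the Hartigan-Wong algorithm by default). It assumes numeric input and picks k random observations as the starting centroids.

```r
# Hypothetical helper: a bare-bones version of the steps listed above
simple_kmeans <- function(x, k, max_iter = 100) {
  x <- as.matrix(x)
  # Step 2: use k randomly chosen observations as the initial centroids
  centroids <- x[sample(nrow(x), k), , drop = FALSE]
  cluster <- rep(0L, nrow(x))
  for (iter in seq_len(max_iter)) {
    # Step 3: squared Euclidean distance from every observation to every centroid
    d2 <- sapply(seq_len(k), function(j) rowSums(sweep(x, 2, centroids[j, ])^2))
    new_cluster <- max.col(-d2, ties.method = "first")  # index of nearest centroid
    # Step 5: stop once the assignments no longer change
    if (all(new_cluster == cluster)) break
    cluster <- new_cluster
    # Step 4: move each centroid to the mean of its assigned observations
    for (j in seq_len(k)) {
      members <- x[cluster == j, , drop = FALSE]
      if (nrow(members) > 0) centroids[j, ] <- colMeans(members)
    }
  }
  # Total within-cluster sum of squares for the final assignment
  wss <- sum((x - centroids[cluster, , drop = FALSE])^2)
  list(cluster = cluster, centers = centroids, tot.withinss = wss, iter = iter)
}

set.seed(42)                     # arbitrary seed
x   <- scale(USArrests)
fit <- simple_kmeans(x, k = 4)
table(fit$cluster)               # cluster sizes
fit$tot.withinss                 # total WSS
```

Because the starting centroids are random, a single run like this can stop in a local optimum; that is the reason kmeans() offers the nstart argument.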

Strengths

  • quite simple to run
  • many use cases
  • can be run on large datasets
  • easy to interpret

Weaknesses

  • have to input the number of clusters (k), so some knowledge of the dataset is required
  • outliers can strongly influence cluster centroids
  • can only handle numerical data
  • assumes that clusters are spherical and contain a roughly similar number of observations
  • results depend on the random starting centroids, so they can differ between runs (setting a seed with set.seed() makes a run repeatable)
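The repeatability point can be checked directly. The sketch below assumes the USArrests data and an arbitrary seed value; nstart = 1 is used on purpose so each call depends on a single random start.

```r
x <- scale(USArrests)

set.seed(2024)                         # arbitrary seed value
run1 <- kmeans(x, centers = 4, nstart = 1)

set.seed(2024)                         # same seed, so the same random start
run2 <- kmeans(x, centers = 4, nstart = 1)

# With the same seed the two runs give the same cluster assignments
identical(run1$cluster, run2$cluster)

# Without re-seeding, single-start runs may end in different local optima.
# A larger nstart keeps the best of several random starts.
many <- kmeans(x, centers = 4, nstart = 25)
c(single_start = run1$tot.withinss, best_of_25 = many$tot.withinss)
```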

What is k-means clustering used for?

Usually applied to numeric, continuous data with a relatively small number of dimensions.

Example areas of use:
- Customer Segmentation
- Document Classification
- Identifying crime localities
- Insurance fraud detection
- Cyber-Profiling criminals

Code

https://uc-r.github.io/kmeans_clustering

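library(tidyverse)  # data manipulation and ggplot2 plotting
library(cluster)    # clustering algorithms
library(factoextra) # clustering algorithms & visualization
library(plotly)     # interactive 3D plots

set.seed(123)       # make the random starts repeatable (seed value is arbitrary)

df <- USArrests
df <- na.omit(df)   # drop rows with missing values
df <- scale(df)     # standardise each variable to mean 0, sd 1

# Total within-cluster sum of squares for k = 1 to kmax
kmax <- 10
WSSus_arrests <- sapply(1:kmax, function(k) kmeans(df, centers = k, nstart = 10)$tot.withinss)

# Elbow plot: look for the k where the curve starts to flatten
plot(1:kmax, WSSus_arrests, type = 'b', xlab = 'k', ylab = 'Total WSS')
abline(v = 4, lty = 2)  # mark the chosen k = 4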
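Reading an elbow off a plot is a judgment call, so it can help to look at the numbers as well. The sketch below is one possible companion to the plot, not a formal selection rule: it tabulates how much the total WSS falls each time k increases by one. The names wss and drop_pct are made up for this example.

```r
set.seed(123)                          # arbitrary seed
x    <- scale(USArrests)
kmax <- 10
wss  <- sapply(1:kmax, function(k) kmeans(x, centers = k, nstart = 10)$tot.withinss)

# Percentage reduction in total WSS gained by moving from k - 1 to k clusters
drop_pct <- c(NA, -diff(wss) / head(wss, -1) * 100)

data.frame(k = 1:kmax, wss = round(wss, 1), drop_pct = round(drop_pct, 1))
```

A k after which drop_pct becomes small is a candidate for the elbow.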

k4 <- kmeans(df, centers = 4, nstart = 25)
df %>%
  as_tibble() %>%
  mutate(cluster = k4$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(UrbanPop, Murder, color = factor(cluster), label = state)) +
  geom_text()

Alaska, Arkansas and Kentucky are all quite similar in their Murder and UrbanPop values, so they must differ on other variables. Let's check.

df %>%
  as_tibble() %>%
  mutate(cluster = k4$cluster,
         state = row.names(USArrests)) %>%
  ggplot(aes(Assault, Murder, color = factor(cluster), label = state)) +
  geom_text()
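The scatter plots only show two variables at a time. A table of per-cluster means covers all four variables at once. This sketch refits the model so that it runs on its own; the seed is an arbitrary assumption, and the cluster numbers may not match those from another run.

```r
set.seed(123)                                  # arbitrary seed
k4 <- kmeans(scale(USArrests), centers = 4, nstart = 25)

# Mean of each original (unscaled) variable within each cluster
cluster_means <- aggregate(USArrests, by = list(cluster = k4$cluster), FUN = mean)
cluster_means

# Cluster membership of the three states discussed above
k4$cluster[c("Alaska", "Arkansas", "Kentucky")]
```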

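# Add cluster labels (as a factor) and state names for plotting
df <- df %>%
  as_tibble() %>%
  mutate(cluster = factor(k4$cluster),
         state = row.names(USArrests))

# Interactive 3D scatter plot; hovering over a point shows the state name
plot_ly(df, x = ~UrbanPop, y = ~Murder, z = ~Assault,
        color = ~cluster, type = "scatter3d", mode = "markers", text = ~state)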